# Copyright Huawei Technologies Co., Ltd. 2025. All rights reserved.
from dataclasses import dataclass

from typing_extensions import Self

from msmodelslim.app.base.quant_config import BaseQuantConfig
from msmodelslim.utils.exception import SchemaValidateError


@dataclass
class QuantSpec:
    # anti
    anti_cfg: dict = None # anti-outlier config
    anti_params: dict = None # anti-outlier params

    # calib
    calib_cfg: dict = None # calib config
    calib_params: dict = None # calib params
    calib_save_params: dict = None # calib save params

    # quantization parameters
    batch_size: int = 4 # batch size
    anti_dataset: str = None # anti-outlier dataset
    calib_dataset: str = None # calib dataset


@dataclass
class ModelslimV0QuantConfig(BaseQuantConfig):
    spec: QuantSpec # quantization config specification

    @classmethod
    def from_base(cls, quant_config: BaseQuantConfig) -> Self:
        return cls(
            apiversion=quant_config.apiversion,
            metadata=quant_config.metadata,
            spec=load_specific_config(quant_config.spec),
        )


def load_specific_config(yaml_spec: object) -> QuantSpec:
    """Load specific configuration from YAML spec"""
    if not isinstance(yaml_spec, dict):
        raise SchemaValidateError("task spec must be dict",
                                  action='Please make sure the task spec is a dictionary')

    config = QuantSpec()
    config.anti_cfg = yaml_spec.get('anti_cfg', None)
    config.anti_params = yaml_spec.get('anti_params', {})
    config.calib_cfg = yaml_spec.get('calib_cfg', {})
    config.calib_params = yaml_spec.get('calib_params', {})
    config.calib_save_params = yaml_spec.get('calib_save_params', {})
    config.batch_size = yaml_spec.get('batch_size', 4)
    config.anti_dataset = yaml_spec.get('anti_dataset', None)
    config.calib_dataset = yaml_spec.get('calib_dataset', 'teacher_qualification.jsonl')
    return config
